#!/usr/bin/env python

# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple changelog generator

This tool takes a list of release versions and generates `CHANGELOG.md`.
The changelog will include all merged PRs w/o `[bot]` postfix,
and issues that include the labels `bug`, `feature request`, `SQL`, `performance`, `shuffle`,
minus any issues with the labels `wontfix`, `invalid` or `duplicate`.

For each project there should be an issue subsection for,
Features: all issues with label `feature request` + `SQL`
Performance: all issues with label `performance` + `shuffle`
Bugs fixed: all issues with label `bug`

To deduplicate section, the priority should be `Bugs fixed > Performance > Features`

NOTE: This is a repo-specific script, so you may not use it in other places.
Release version from arguments will be used to map project name and branch name.
Mapping Pattern:
release -> project: Release {release}, branch: branch-{release}
e.g.
    0.1 -> project: Release 0.1, branch: branch-0.1
    0.2 -> project: Release 0.2, branch: branch-0.2

Dependencies:
    - requests

GitHub personal access token: https://github.com/settings/tokens, and make sure you have the `repo` and `project` scopes selected

Usage:
    cd spark-rapids/

    # generate changelog for releases X, Y, Z
    scripts/generate-changelog --token=<GITHUB_PERSONAL_ACCESS_TOKEN> \
    --releases=X,Y,Z

    # To a separate file like /tmp/CHANGELOG.md
    GITHUB_TOKEN=<GITHUB_PERSONAL_ACCESS_TOKEN> scripts/generate-changelog \
    --releases=X,Y,Z \
    --path=/tmp/CHANGELOG.md
"""
import os
import sys
from argparse import ArgumentParser
from collections import OrderedDict
from datetime import date, datetime
from dateutil.relativedelta import relativedelta

import requests

# Constants
# Prefix of a project name; combined with a version as "Release <version>"
RELEASE = "Release"
# Resource types; also used as the key under `repository` when reading GraphQL responses
PULL_REQUESTS = "pullRequests"
ISSUES = "issues"
# Subtitles (changelog categories); INVALID is what rules() returns for items
# that carry an excluding label or no recognized label at all
INVALID = 'Invalid'
BUGS_FIXED = 'Bugs Fixed'
PERFORMANCE = 'Performance'
FEATURES = 'Features'
PRS = 'PRs'
# Labels (GitHub label names that rules() maps to the categories above)
LABEL_WONTFIX, LABEL_INVALID, LABEL_DUPLICATE = 'wontfix', 'invalid', 'duplicate'
LABEL_BUG = 'bug'
LABEL_PERFORMANCE, LABEL_SHUFFLE = 'performance', 'shuffle'
LABEL_FEATURE, LABEL_SQL = 'feature request', 'SQL'
# Queries
# GraphQL query for merged pull requests of NVIDIA/spark-rapids whose base branch is
# `$baseRefName`, 100 per page; `$after` is the pagination cursor taken from
# `pageInfo.endCursor`. For each PR it requests up to 10 labels and, for up to
# 10 project items, the single-select value of the "Roadmap" field.
query_pr = """
query ($baseRefName: String!, $after: String) {
  repository(name: "spark-rapids", owner: "NVIDIA") {
    pullRequests(states: [MERGED], baseRefName: $baseRefName, first: 100, after: $after) {
      totalCount
      nodes {
        number
        title
        headRefName
        baseRefName
        state
        url
        labels(first: 10) {
          nodes {
            name
          }
        }
        projectItems(first: 10) {
          nodes {
              roadmap: fieldValueByName(name: "Roadmap") {
                ... on ProjectV2ItemFieldSingleSelectValue {
                    name
                }
              }
          }
        }
        mergedAt
      }
      pageInfo {
        hasNextPage
        endCursor
      }
    }
  }
}
"""
# GraphQL query for closed issues of NVIDIA/spark-rapids, filtered by the labels
# "SQL", "feature request", "performance", "bug" and "shuffle" and by the `$since`
# timestamp, 100 per page; `$after` is the pagination cursor taken from
# `pageInfo.endCursor`. For each issue it requests up to 10 labels and, for up to
# 10 project items, the single-select value of the "Roadmap" field.
query_issue = """
query ($after: String, $since: DateTime) {
  repository(name: "spark-rapids", owner: "NVIDIA") {
    issues(states: [CLOSED], labels: ["SQL", "feature request", "performance", "bug", "shuffle"], first: 100, after: $after, filterBy: {since: $since}) {
      totalCount
      nodes {
        number
        title
        state
        url
        labels(first: 10) {
          nodes {
            name
          }
        }
        projectItems(first: 10) {
          nodes {
            roadmap: fieldValueByName(name: "Roadmap") {
              ... on ProjectV2ItemFieldSingleSelectValue {
                name
              }
            }
          }
        }
        closedAt
      }
      pageInfo {
        hasNextPage
        endCursor
      }
    }
  }
}
"""


def _roadmap_version(item: dict):
    """Return the first non-empty "Roadmap" value among an item's project entries.

    The queries request up to 10 project items per PR/issue, so every entry is
    inspected instead of only the first one. Missing or null `projectItems`,
    null nodes and entries without a Roadmap value are skipped.

    Returns:
        The Roadmap option name (e.g. "0.2"), or None when no entry has one.
    """
    nodes = (item.get('projectItems') or {}).get('nodes') or []
    for node in nodes:
        roadmap = (node or {}).get('roadmap')
        if roadmap and roadmap.get('name'):
            return roadmap['name']
    return None


def process_changelog(resource_type: str, changelog: dict, releases: set, projects: set, token: str):
    """Fetch PRs or issues and file them into ``changelog`` by release and category.

    Args:
        resource_type: PULL_REQUESTS or ISSUES; any other value prints an error
            and exits the process with status 1.
        changelog: dict updated in place, keyed by project name
            ("Release <version>"), each value mapping a subtitle to a list of
            entries with the keys "number", "title", "url" and "time".
        releases: release versions to fetch data for.
        projects: project names that may appear in the changelog; items mapped
            to any other project are ignored.
        token: GitHub token used for the API requests.
    """
    if resource_type == PULL_REQUESTS:
        items = process_pr(releases=releases, token=token)
        time_field = 'mergedAt'
    elif resource_type == ISSUES:
        items = process_issue(releases=releases, token=token)
        time_field = 'closedAt'
    else:
        print(f"[process_changelog] Invalid type: {resource_type}")
        sys.exit(1)

    is_pr = resource_type == PULL_REQUESTS
    for item in items:
        title = item['title']
        if is_pr and '[bot]' in title:
            continue  # skip auto-gen PR

        ver = _roadmap_version(item)
        if not ver:
            if not is_pr:
                continue  # an issue without a Roadmap value cannot be mapped to a release
            # Obtain the version from the PR's target branch, e.g. branch-x.y --> x.y
            ver = item['baseRefName'].replace('branch-', '')
        project = f"{RELEASE} {ver}"

        if not release_project(project, projects):
            continue

        if is_pr:
            category = PRS
            if '[databricks]' in title:  # strip ambiguous CI annotation
                title = title.replace('[databricks]', '').strip()
        else:
            category = rules({label['name'] for label in item['labels']['nodes']})
            if category == INVALID:
                continue

        # Create the release section only once an entry is really added, so that
        # releases whose items were all skipped do not get an empty header.
        section = changelog.setdefault(project, {
            FEATURES: [],
            PERFORMANCE: [],
            BUGS_FIXED: [],
            PRS: [],
        })
        section[category].append({
            "number": item['number'],
            "title": title,
            "url": item['url'],
            "time": item[time_field],
        })


def process_pr(releases: set, token: str):
    """Return the merged pull requests of every `branch-<release>` base branch.

    One paginated fetch is issued per release; the results are concatenated in
    the iteration order of ``releases``.
    """
    return [
        pull
        for version in releases
        for pull in fetch(resource_type=PULL_REQUESTS, token=token,
                          variables={'baseRefName': f"branch-{version}"})
    ]


def process_issue(releases: set, token: str):
    """Return closed issues within a look-back window sized by the release count.

    The window is three months per requested release, counted back from the
    current UTC time, and is passed to the issue query as the `since` variable.
    """
    # Local import: the module-level import only brings in `date` and `datetime`.
    from datetime import timezone

    # `datetime.utcnow()` is deprecated since Python 3.12. Take an aware "now" and
    # drop the tzinfo so the ISO string sent to the API keeps its previous format.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    since = (now_utc - relativedelta(months=3 * len(releases))).isoformat()
    return fetch(resource_type=ISSUES, token=token, variables={'since': since})


def fetch(resource_type: str, token: str, variables: dict = None):
    """Run the paginated GraphQL query for a resource type and collect all nodes.

    Args:
        resource_type: PULL_REQUESTS (requires non-empty ``variables``) or ISSUES.
            Any other combination returns an empty list without a request.
        token: GitHub token used for the API requests.
        variables: GraphQL variables for the query; not modified by this call.

    Returns:
        A list with the `nodes` of every page.

    Raises:
        Exception: when a request does not return HTTP 200, or when the response
            carries no data for the requested resource (e.g. GraphQL errors).
    """
    if resource_type == PULL_REQUESTS and variables:
        q = query_pr
    elif resource_type == ISSUES:
        q = query_issue
    else:
        return []

    # Work on a private copy: the cursor is written into it on every page, and
    # `variables` may be None or a dict the caller still owns.
    page_vars = dict(variables) if variables else {}
    items = []
    has_next = True
    while has_next:
        res = post(query=q, token=token, variable=page_vars)
        if res.status_code != 200:
            raise Exception("Query failed to run by returning code of {}. {}".format(res.status_code, q))
        d = res.json()
        # A 200 response can still carry only `errors`; report them instead of
        # failing later with an opaque TypeError on a missing `data` payload.
        connection = ((d.get('data') or {}).get('repository') or {}).get(resource_type)
        if connection is None:
            raise Exception("Query returned no {} data. Errors: {}".format(resource_type, d.get('errors')))
        page_info = connection["pageInfo"]
        has_next = page_info["hasNextPage"]
        page_vars['after'] = page_info["endCursor"]
        items.extend(connection['nodes'])
    return items


def post(query: str, token: str, variable: dict, timeout: float = 60):
    """Send one GraphQL request to the GitHub API and return the raw response.

    Args:
        query: GraphQL query text.
        token: GitHub token, sent in the `Authorization` header.
        variable: GraphQL variables for the query.
        timeout: seconds to wait for the server before giving up; without it
            `requests` would wait indefinitely on a stalled connection.

    Returns:
        The `requests` response object; the status code is not checked here.
    """
    return requests.post('https://api.github.com/graphql',
                         json={'query': query, 'variables': variable},
                         headers={"Authorization": f"token {token}"},
                         timeout=timeout)


def release_project(project_name: str, projects: set):
    """Tell whether ``project_name`` is one of the requested release projects."""
    return project_name in projects


def rules(labels: set):
    """Map a set of label names to a changelog category.

    Label groups are checked in priority order: excluding labels first, then
    bugs, then performance, then features. INVALID is returned both for
    excluding labels and when no group matches.
    """
    precedence = (
        ((LABEL_WONTFIX, LABEL_INVALID, LABEL_DUPLICATE), INVALID),
        ((LABEL_BUG,), BUGS_FIXED),
        ((LABEL_PERFORMANCE, LABEL_SHUFFLE), PERFORMANCE),
        ((LABEL_FEATURE, LABEL_SQL), FEATURES),
    )
    for names, category in precedence:
        if any(name in labels for name in names):
            return category
    return INVALID


def _release_sort_key(project_name: str):
    """Build a sort key that orders project names by numeric version components."""
    version = project_name.rpartition(' ')[2]
    # Compare components as numbers so that e.g. 0.10 ranks above 0.9; components
    # that are not numbers rank below any number, and the name breaks ties.
    components = tuple(int(part) if part.isdecimal() else -1 for part in version.split('.'))
    return components, project_name


def form_changelog(path: str, changelog: dict):
    """Render ``changelog`` as Markdown and write it to ``path``.

    Releases are listed newest first (by numeric version order). Each release
    gets the Features, Performance, Bugs Fixed and PRs subsections, in that
    order, followed by a fixed pointer to the archived changelogs.

    Args:
        path: output file; overwritten if it exists.
        changelog: mapping of project name to its subtitle -> entries dict.
    """
    subsections = ""
    for project_name in sorted(changelog, key=_release_sort_key, reverse=True):
        issues = changelog[project_name]
        subsections += f"\n\n## {project_name}"
        for subtitle in (FEATURES, PERFORMANCE, BUGS_FIXED, PRS):
            subsections += form_subsection(issues, subtitle)
    markdown = (
        "# Change log\n"
        f"Generated on {date.today()}{subsections}\n"
        "\n## Older Releases\n"
        "Changelog of older releases can be found at [docs/archives](/docs/archives)\n"
    )
    # Explicit encoding: titles may contain non-ASCII text and the platform
    # default encoding is not guaranteed to handle it.
    with open(path, "w", encoding="utf-8") as file:
        file.write(markdown)


def form_subsection(issues: dict, subtitle: str):
    """Render one subsection of a release as a two-column Markdown table.

    Args:
        issues: mapping of subtitle to a list of entries with the keys
            "number", "title", "url" and "time".
        subtitle: the subsection to render; must be a key of ``issues``.

    Returns:
        The Markdown text, newest entry first, or an empty string when the
        subsection has no entries.
    """
    entries = issues[subtitle]
    if not entries:
        return ''
    rows = [f"\n\n### {subtitle}", "\n|||\n|:---|:---|"]
    for issue in sorted(entries, key=lambda x: x['time'], reverse=True):
        # A bare `|` in a title would end the table cell early; escape it.
        title = issue['title'].replace('|', '\\|')
        rows.append(f"\n|[#{issue['number']}]({issue['url']})|{title}|")
    return ''.join(rows)



def main(rels: str, path: str, token: str):
    """Generate the changelog file for the given releases.

    Args:
        rels: comma-separated release versions, e.g. "0.1,0.2".
        path: output path of the generated changelog.
        token: GitHub token used for the API requests.

    Any exception raised while fetching or writing is printed to stderr and the
    process exits with status 1.
    """
    print('Generating changelog ...')

    try:
        changelog = {}  # changelog dict
        # Drop blank entries (e.g. from a trailing comma): they would query a
        # bogus `branch-` base and widen the issue look-back window.
        releases = {name for name in (x.strip() for x in rels.split(',')) if name}
        projects = {f"{RELEASE} {rel}" for rel in releases}

        print('Processing pull requests ...')
        process_changelog(resource_type=PULL_REQUESTS, changelog=changelog,
                          releases=releases, projects=projects, token=token)
        print('Processing issues ...')
        process_changelog(resource_type=ISSUES, changelog=changelog,
                          releases=releases, projects=projects, token=token)
        # form doc
        form_changelog(path=path, changelog=changelog)
    except Exception as e:  # pylint: disable=broad-except
        print(e, file=sys.stderr)
        sys.exit(1)

    print('Done.')


# Script entry: parse the command line, resolve the GitHub token and run main().
if __name__ == '__main__':
    parser = ArgumentParser(description="Changelog Generator")
    parser.add_argument("--releases", help="list of release versions, separated by comma",
                        default="0.1,0.2,0.3")
    parser.add_argument("--token", help="github token, will use GITHUB_TOKEN if empty", default='')
    parser.add_argument("--path", help="path for generated changelog file", default='./CHANGELOG.md')
    args = parser.parse_args()

    # The --token argument wins; the GITHUB_TOKEN environment variable is the fallback.
    github_token = args.token or os.environ.get('GITHUB_TOKEN')
    if not github_token:
        # An `assert` would be stripped under `python -O`; report a usage error instead.
        parser.error('env GITHUB_TOKEN should not be empty')

    main(rels=args.releases, path=args.path, token=github_token)
